{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "class Model4_6(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model4_6, self).__init__() \n",
    "        self.cnn = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.maxpool=nn.MaxPool2d(kernel_size=2)\n",
    "        \n",
    "        self.fc1 = nn.Linear(32*14*14, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.cnn(x)\n",
    "        out = self.relu(out)\n",
    "        out = self.maxpool(out)\n",
    "        out=out.view(out.size(0), -1)\n",
    "        out = self.fc1(out)\n",
    "        return out\n",
    "\n",
    "model4_6 = Model4_6() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Model4_7(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model4_7, self).__init__() \n",
    "        self.layer1 = nn.Sequential(nn.Conv2d(1,16, kernel_size=5, stride=1, padding=2),\n",
    "             nn.ReLU(),\n",
    "             nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "        self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),\n",
    "             nn.ReLU(),\n",
    "             nn.MaxPool2d(kernel_size=2, stride=2))        \n",
    "        self.fc1 = nn.Linear(32*7*7, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out= self.layer1(x)\n",
    "        out= self.layer2(out)\n",
    "        out=out.view(out.size(0), -1)\n",
    "        out = self.fc1(out)\n",
    "        return out\n",
    "\n",
    "model4_7 = Model4_7() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "batchSize = 100\n",
    "trainSet = torchvision.datasets.MNIST(root='./data', train = True, transform=transforms.ToTensor(), download=True)\n",
    "trainLoader = torch.utils.data.DataLoader(dataset=trainSet, batch_size=batchSize, shuffle = True)\n",
    "testSet = torchvision.datasets.MNIST(root='./data', train = False, transform=transforms.ToTensor(), download=True)\n",
    "testLoader = torch.utils.data.DataLoader(dataset=testSet, batch_size=batchSize, shuffle = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "import time\n",
    "\n",
    "def benchmark(trainLoader, model, epochs=1, lr=0.01):\n",
    "    model.__init__()\n",
    "    start=time.time()\n",
    "    optimiser = optim.SGD(model.parameters(), lr=lr)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    for epoch in range(epochs):\n",
    "        for i, (images, labels) in enumerate(trainLoader):\n",
    "            optimiser.zero_grad()\n",
    "            outputs = model(images)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimiser.step()\n",
    "    print('Accuracy: {0:.4f}'.format(accuracy(testLoader,model)))\n",
    "    print('Training time: {0:.2f}'.format(time.time() - start))       "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def accuracy(testLoader,model):    \n",
    "    correct, total = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for data in testLoader:\n",
    "            images, labels = data\n",
    "            outputs = model(images)\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()      \n",
    "    return(correct / total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 convolutional layers\n",
      "Accuracy: 0.9654\n",
      "Training time: 50.90\n",
      "2 convolutional layers\n",
      "Accuracy: 0.9743\n",
      "Training time: 84.47\n"
     ]
    }
   ],
   "source": [
    "print('1 convolutional layers')\n",
    "benchmark(trainLoader,model4_6, lr=0.1)\n",
    "print('2 convolutional layers')\n",
    "benchmark(trainLoader,model4_7, lr=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 1, 5, 5])\n",
      "torch.Size([16])\n",
      "torch.Size([32, 16, 5, 5])\n",
      "torch.Size([32])\n",
      "torch.Size([10, 1568])\n",
      "torch.Size([10])\n"
     ]
    }
   ],
   "source": [
    "parameters=list((model.parameters()))\n",
    "for parameter in parameters:\n",
    "    print(parameter.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1568"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
